{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a28ecfc6",
   "metadata": {},
   "source": [
    "# Knowledge distillation implementation\n",
    "\n",
    "Read about 5 more techniques for model compression here: [Machine Learning Model Compression: A Critical Step Towards Efficient Deep Learning](https://www.dailydoseofds.com/model-compression-a-critical-step-towards-efficient-machine-learning)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "361b163a",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "48a9472e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from time import time\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4b1a8db",
   "metadata": {},
   "source": [
    "# Load the MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "49473790",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
    "trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "trainloader = DataLoader(trainset, batch_size=64, shuffle=True)\n",
    "\n",
    "testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "testloader = DataLoader(testset, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b00f58e",
   "metadata": {},
   "source": [
    "# Knowledge Distillation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d41ef7bf",
   "metadata": {},
   "source": [
    "## Teacher Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "13b9dc0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TeacherNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(TeacherNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 5)\n",
    "        self.pool = nn.MaxPool2d(5, 5)\n",
    "        self.fc1 = nn.Linear(32 * 4 * 4, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = self.pool(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x        "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94397b81",
   "metadata": {},
   "source": [
    "## Evaluation function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4797e7bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for data in testloader:\n",
    "            inputs, labels = data\n",
    "            outputs = model(inputs)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "    return correct / total"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f775e8d",
   "metadata": {},
   "source": [
    "## Initialize and train the teacher model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f6d78f6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "teacher_model = TeacherNet()\n",
    "teacher_optimizer = optim.Adam(teacher_model.parameters(), lr=0.001)\n",
    "teacher_criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "44cf55cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.23366064861265898, Accuracy: 97.60%\n",
      "Epoch 2, Loss: 0.07699692965661889, Accuracy: 98.00%\n",
      "Epoch 3, Loss: 0.058064278137973394, Accuracy: 98.44%\n",
      "Epoch 4, Loss: 0.04937064894107677, Accuracy: 98.24%\n",
      "Epoch 5, Loss: 0.04162352114517703, Accuracy: 98.53%\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(5):\n",
    "    teacher_model.train()\n",
    "    running_loss = 0.0\n",
    "    \n",
    "    for data in trainloader:\n",
    "        inputs, labels = data\n",
    "        teacher_optimizer.zero_grad()\n",
    "        outputs = teacher_model(inputs)\n",
    "        loss = teacher_criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        teacher_optimizer.step()\n",
    "        \n",
    "        running_loss += loss.item()\n",
    "        \n",
    "    teacher_accuracy = evaluate(teacher_model)\n",
    "        \n",
    "    print(f\"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}, Accuracy: {teacher_accuracy * 100:.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d06c0fa1",
   "metadata": {},
   "source": [
    "## Student Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f56ad7ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "class StudentNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(StudentNet, self).__init__()\n",
    "        self.fc1 = nn.Linear(28 * 28, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "497fa2b1",
   "metadata": {},
   "source": [
    "## Initialize and train the teacher model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d1acc58d",
   "metadata": {},
   "outputs": [],
   "source": [
    "student_model = StudentNet()\n",
    "student_optimizer = optim.Adam(student_model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c925045b",
   "metadata": {},
   "source": [
    "## Loss function (KL Divergence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4aead620",
   "metadata": {},
   "outputs": [],
   "source": [
    "def knowledge_distillation_loss(student_logits, teacher_logits):\n",
    "    p_teacher = F.softmax(teacher_logits , dim=1)\n",
    "    p_student = F.log_softmax(student_logits, dim=1)\n",
    "    loss = F.kl_div(p_student, p_teacher, reduction='batchmean')\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d7037a37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 1.97617478094473, Accuracy: 93.53%\n",
      "Epoch 2, Loss: 0.9071605966373044, Accuracy: 94.67%\n",
      "Epoch 3, Loss: 0.6211776698874251, Accuracy: 96.30%\n",
      "Epoch 4, Loss: 0.48355193005483244, Accuracy: 96.29%\n",
      "Epoch 5, Loss: 0.4033386060778218, Accuracy: 96.34%\n"
     ]
    }
   ],
   "source": [
    "# Train the student model with knowledge distillation\n",
    "for epoch in range(5):  # You can adjust the number of epochs\n",
    "    student_model.train()\n",
    "    running_loss = 0.0\n",
    "    \n",
    "    for data in trainloader:\n",
    "        inputs, labels = data\n",
    "        student_optimizer.zero_grad()\n",
    "        student_logits = student_model(inputs)\n",
    "        teacher_logits = teacher_model(inputs).detach()  # Detach the teacher's output to avoid backpropagation\n",
    "        loss = knowledge_distillation_loss(student_logits, teacher_logits)\n",
    "        loss.backward()\n",
    "        student_optimizer.step()\n",
    "        \n",
    "        running_loss += loss.item()\n",
    "    \n",
    "    student_accuracy = evaluate(student_model)\n",
    "        \n",
    "    print(f\"Epoch {epoch + 1}, Loss: {running_loss / len(testloader)}, Accuracy: {student_accuracy * 100:.2f}%\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3c70e134",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.61 s ± 21.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit evaluate(teacher_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "612d1f89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.09 s ± 63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%timeit evaluate(student_model) # student model runs faster"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a01dc9c2",
   "metadata": {},
   "source": [
    "Read about 5 more techniques for model compression here: [Machine Learning Model Compression: A Critical Step Towards Efficient Deep Learning](https://www.dailydoseofds.com/model-compression-a-critical-step-towards-efficient-machine-learning)"
   ]
  }
 ],
 "metadata": {
  "finalized": {
   "timestamp": 1694847489806,
   "trusted": false
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
